import datetime
import tempfile
import os
from pathlib import Path
import traceback
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.policy.policy import PolicySpec

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.utils.checkpoints import get_checkpoint_info
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.tune.registry import register_env
from ray.tune.logger import UnifiedLogger

from UnitCell_Environment.unitcell_environment.env.hier_paral_unitcell_environment import HierParalUnitCellEnvironment
from UnitCell_Environment.unitcell_environment.env.utils import ELEMENTS

def paral_env_creator(config):
    """Build a ``HierParalUnitCellEnvironment`` from ``config``.

    Intended as the creator callable handed to ``register_env``; the config
    is forwarded to the environment constructor unchanged.
    """
    environment = HierParalUnitCellEnvironment(config)
    return environment

def custom_log_creator(custom_path, custom_str):
    """Create a fresh, uniquely named log directory and a matching logger factory.

    The directory is created inside ``custom_path``. ``custom_path`` and any
    missing parents are created first. The directory name starts with
    ``"<custom_str>_<YYYY-mm-dd_HH-MM-SS>"``, and ``tempfile.mkdtemp`` appends a
    random suffix to that prefix.

    Args:
        custom_path: Parent directory for the log directory.
        custom_str: Label used at the start of the directory name.

    Returns:
        A ``(logger_creator, logdir)`` tuple. ``logger_creator(config)`` returns
        a ``UnifiedLogger`` that writes into ``logdir``.
    """
    timestr = datetime.datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
    logdir_prefix = "{}_{}".format(custom_str, timestr)

    # exist_ok avoids the check-then-create race when several runs start at once.
    os.makedirs(custom_path, exist_ok=True)

    logdir = tempfile.mkdtemp(prefix=logdir_prefix, dir=custom_path)

    def logger_creator(config):
        return UnifiedLogger(config, logdir, loggers=None)

    return logger_creator, logdir


class CustomMetricsCallback(DefaultCallbacks):
    """RLlib callbacks that record the environment's optimisation metrics per episode."""

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        """Store the ending env's optimisation metrics in ``episode.custom_metrics``.

        Reads ``optimisation_step`` and ``last_optimisation_len`` from the
        sub-environment that finished the episode. They are stored under the
        keys ``"optimisation_step"`` and ``"optimisation_len"``.
        """
        sub_envs = base_env.get_sub_environments()
        # Use the sub-environment whose episode actually ended. Indexing [0]
        # would report another env's metrics when several envs run per worker.
        env = sub_envs[env_index if env_index is not None else 0]

        # If that env itself exposes sub-environments, unwrap one level.
        # The method must be called and indexed; assigning the bound method
        # makes the attribute reads below fail.
        if hasattr(env, "get_sub_environments"):
            env = env.get_sub_environments()[0]

        episode.custom_metrics["optimisation_step"] = env.optimisation_step
        episode.custom_metrics["optimisation_len"] = env.last_optimisation_len


def create_config(env_config, algo_config, shared_policy=True):
    """Build the (unbuilt) SAC multi-agent config for the unit-cell environment.

    Steps:
      * registers ``paral_env_creator`` under the env name;
      * sets up the policies and the agent-to-policy mapping;
      * attaches ``CustomMetricsCallback``;
      * applies ``algo_config``;
      * points the config at the registered env;
      * selects the old (non-RLModule / non-ConnectorV2) API stack.

    Args:
        env_config: Environment config dict. ``"env_name"`` defaults to
            ``"env_name"``. ``"comp_name"`` defaults to ``"SrTiO3x8"`` and
            selects the ``ELEMENTS`` entry. The whole dict is passed to the
            environment.
        algo_config: Dict applied with ``update_from_dict``.
        shared_policy: If True (the default, matching the previous hard-coded
            behaviour), all agents share a single ``"2_stepsize"`` policy. If
            False, one ``"2_stepsize_<element>"`` policy is created per element.

    Returns:
        The configured ``SACConfig`` instance.

    Raises:
        KeyError: if ``comp_name`` is not a key of ``ELEMENTS``.
    """
    env_name = env_config.get("env_name", "env_name")
    # Looked up eagerly (as before), so an unknown comp_name raises here.
    elements = ELEMENTS[env_config.get("comp_name", "SrTiO3x8")]

    register_env(env_name, paral_env_creator)

    if shared_policy:
        # Agent ids map to their first two "_"-separated parts.
        # Assumes agent ids look like "2_stepsize_<element>..." — confirm
        # against the env.
        policy_ids = ["2_stepsize"]
        num_id_parts = 2
    else:
        # One policy per element: agent ids map to their first three parts.
        policy_ids = [f"2_stepsize_{e}" for e in elements]
        num_id_parts = 3

    def policy_mapping_fn(agent_id, episode, worker, **kw):
        parts = agent_id.split("_")
        if len(parts) < num_id_parts:
            raise ValueError(
                f"Agent id {agent_id!r} has fewer than {num_id_parts} "
                "'_'-separated parts; cannot map it to a policy"
            )
        return "_".join(parts[:num_id_parts])

    config = SACConfig().multi_agent(
        policies={policy_id: PolicySpec() for policy_id in policy_ids},
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=list(policy_ids),
    )

    config.callbacks(CustomMetricsCallback)
    config.update_from_dict(algo_config)
    config.environment(env=env_name, env_config=env_config)
    config.api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)

    return config


def save_checkpoint(algo, logdir, iteration):
    """Save ``algo`` into ``<logdir>/checkpoint_<iteration:06d>``.

    The absolute target directory, including any missing parents, is created
    first. Its path is then passed to ``algo.save_checkpoint``.
    """
    target = os.path.abspath(f"{logdir}/checkpoint_{iteration:06d}")
    os.makedirs(target, exist_ok=True)
    algo.save_checkpoint(checkpoint_dir=target)


def training_task(env_config, algo_config, iterations, checkpoint=None, checkpoint_every=100):
    """Train the algorithm for ``iterations`` iterations and checkpoint it.

    A new log directory is created under ``ray_results``. If ``checkpoint`` is
    given, the algorithm is restored from it; otherwise a fresh one is built
    with ``create_config``.

    Checkpoints are written to the log directory:
      * before training (iteration 0);
      * every ``checkpoint_every`` iterations;
      * once more at the end if the last iteration was not already saved.

    Args:
        env_config: Environment config dict (see ``create_config``).
        algo_config: Algorithm settings dict (see ``create_config``). Ignored
            when restoring from ``checkpoint``.
        iterations: Number of ``algo.train()`` calls.
        checkpoint: Optional checkpoint path to resume from.
        checkpoint_every: Interval, in iterations, between periodic
            checkpoints. Must be positive.

    Returns:
        Absolute path of the last checkpoint directory written by this run.

    Raises:
        ValueError: if ``checkpoint_every`` is not positive.
        SystemExit: with status 1 if ``algo.train()`` raises ValueError, after
            the error and traceback are written to ``log_error.txt``.
    """
    if checkpoint_every <= 0:
        raise ValueError(f"checkpoint_every must be positive, got {checkpoint_every}")

    env_name = env_config.get("env_name", "env_name")

    # NOTE(review): expanduser is a no-op without a leading "~", so this is
    # ./ray_results relative to the CWD. Confirm whether "~/ray_results" was
    # intended.
    logger_creator, logdir = custom_log_creator(os.path.expanduser("ray_results"),
                                                f"MARL_{env_name}")

    if checkpoint is not None:
        # Register the env creator in this process before restoring, so the
        # restored algorithm can resolve the env by name.
        register_env(env_name, paral_env_creator)
        algo = Algorithm.from_checkpoint(checkpoint)
    else:
        algo = create_config(env_config, algo_config).build(logger_creator=logger_creator)

    save_checkpoint(algo, logdir, 0)
    last_saved = 0

    for i in range(iterations):
        try:
            algo.train()
        except ValueError as e:
            with open("log_error.txt", "w") as f:
                f.write('An exceptional thing happened - %s' % e)
                f.write(str(traceback.format_exc()))
            # Non-zero status. quit() exited with 0 and relies on the optional
            # `site` builtin.
            raise SystemExit(1) from e

        step = i + 1
        if step % checkpoint_every == 0:
            save_checkpoint(algo, logdir, step)
            last_saved = step

    # Keep the final trained state even when iterations is not a multiple
    # of checkpoint_every.
    if iterations > last_saved:
        save_checkpoint(algo, logdir, iterations)
        last_saved = iterations

    # Same path format as save_checkpoint.
    return os.path.abspath(f"{logdir}/checkpoint_{last_saved:06d}")
